//--------------------------------------------------------------------------
// aes: iterative AES datapath wrapper.
//
// Accepts one request beat on the aes_req_vld/aes_req_rdy handshake, loops it
// through aes_key_add -> aes_line_shifter_sbox -> aes_col_mixer, and presents
// the result on the aes_rsp_vld/aes_rsp_rdy handshake. aes_req_rdy is low
// from the accepted request until the response handshake completes.
//
// Only aes_req_data[127:0] is consumed. Only aes_rsp_data[127:0] carries a
// result; aes_rsp_data[255:128] is driven to zero.
//
// Chaining: on an SOF beat the chaining register (last_data) is loaded with
// csr_aes_init_vector. After that, every accepted response is loaded into it.
// last_data and csr_aes_work_mode are handed to aes_key_add. How they are
// combined is decided inside that submodule (presumably the CBC XOR; verify).
//--------------------------------------------------------------------------
module aes(
    input           clk,
    input           rst_n,              // active-low asynchronous reset

    input           aes_req_vld,
    input           aes_req_sof,        // start of a chained message: reload IV
    input           aes_req_eof,        // NOTE(review): currently unused
    input   [255:0] aes_req_data,       // only [127:0] is used
    output          aes_req_rdy,

    output reg      aes_rsp_vld,
    output  [255:0] aes_rsp_data,       // [127:0] result, [255:128] zero
    input           aes_rsp_rdy,

    input           csr_aes_work_mode,  //0:ecb 1:cbc
    input           csr_aes_mode,       //0:encode 1:decode
    input   [255:0] csr_aes_key,
    input   [1:0]   csr_aes_key_mode,   //00:128-bits  01:192-bits  10:256-bits (any other code also selects the 256-bit schedule)
    input   [127:0] csr_aes_init_vector
);

//--------------------------------------------
// Key-length encodings of csr_aes_key_mode (see port comment above).
localparam [1:0] KEY_MODE_128 = 2'b00;
localparam [1:0] KEY_MODE_192 = 2'b01;
localparam [1:0] KEY_MODE_256 = 2'b10;

//--------------------------------------------
wire    [15:0][255:0]   sbox;
wire    [15:0][255:0]   de_sbox;    // NOTE(review): produced but not consumed anywhere in this module

//extend key var
wire    [9:0][255:0]    key_extend;
wire                    extend_done;

//key add var
wire    [127:0]         data_key_add;
wire    [15:0][7:0]     data_key_add_matrix;

//line shifter var
wire    [15:0][7:0]     data_line_shifter_matrix;
reg     [15:0][7:0]     data_line_shifter_matrix_d;
wire    [127:0]         data_line_shifter_vector;
//col mixer var
wire    [15:0][7:0]     data_col_mix_matrix;
wire    [127:0]         data_col_mix_vector;

//aes control var
reg     [127:0]         aes_req_data_d;
wire                    aes_receive;
reg                     aes_busy;
reg     [4:0]           operate_cnt;
wire    [3:0]           round_cnt;
reg     [127:0]         round_key;
wire    [127:0]         round_din;
reg     [127:0]         last_data;
//--------------------------------------------

// Request handshake: accept only while idle.
assign aes_receive = aes_req_vld & aes_req_rdy;
assign aes_req_rdy = ~aes_busy;

aes_sbox x_aes_sbox(
    .sbox(sbox),
    .de_sbox(de_sbox)
);


// Busy from the accepted request until the response handshake completes.
always @(posedge clk or negedge rst_n)begin
    if(!rst_n)
        aes_busy <= 1'b0;
    else if(aes_receive)
        aes_busy <= 1'b1;
    else if(aes_rsp_vld & aes_rsp_rdy)
        aes_busy <= 1'b0;
end


// Key expansion runs while a request is being processed (key_vld = aes_busy).
aes_key_extend x_aes_key_extend(
    .clk(clk),
    .rst_n(rst_n),

    .key_vld(aes_busy),
    .key(csr_aes_key[255:0]),
    .key_mode(csr_aes_key_mode[1:0]),
    .sbox(sbox),

    .key_extend(key_extend),
    .extend_done(extend_done)
);

// Round step counter; round_cnt = operate_cnt[4:1], so each round lasts two
// clock cycles (presumably one for the shifter register and one for the
// clocked col mixer; confirm against aes_col_mixer).
// NOTE(review): as written this counter is only ever loaded with 0 (reset or
// extend_done) or held. It never advances, so round_cnt stays 0 and
// aes_rsp_vld can never assert. The missing increment condition depends on
// the extend_done protocol of aes_key_extend (pulse vs. level). Confirm and add it.
always @(posedge clk or negedge rst_n)begin
    if(!rst_n)
        operate_cnt[4:0] <= 5'b0;
    else if(extend_done)
        operate_cnt[4:0] <= 5'b0;
    else
        operate_cnt[4:0] <= operate_cnt[4:0];
end

assign round_cnt[3:0] = operate_cnt[4:1];

// Round-key select. Round 0 uses csr_aes_key[127:0] (default branch).
// key_extend entries are consumed in order starting at round 1:
//   128-bit: round r   = key_extend[r-1][127:0]
//   192-bit: 192-bit entries, each pair of entries covers three round keys
//   256-bit: round 2n-1 = key_extend[n-1][127:0], round 2n = key_extend[n-1][255:128]
//            (rounds 1..14 -> key_extend[0..6])
always @(*)begin
    case(round_cnt[3:0])
        4'd1:   round_key[127:0] = key_extend[0][127:0];

        4'd2:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[1][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ? {key_extend[1][63:0],key_extend[0][191:128]} :
                                                                            key_extend[0][255:128];

        4'd3:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[2][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[1][191:64] :
                                                                            key_extend[1][127:0];

        4'd4:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[3][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[2][127:0] :
                                                                            key_extend[1][255:128];

        4'd5:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[4][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ? {key_extend[3][63:0],key_extend[2][191:128]} :
                                                                            key_extend[2][127:0];

        4'd6:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[5][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[3][191:64] :
                                                                            key_extend[2][255:128];

        4'd7:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[6][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[4][127:0] :
                                                                            key_extend[3][127:0];

        // 256-bit rounds 8..14 continue the key_extend[(r-1)/2] sequence
        // (previously skipped key_extend[3][255:128] and key_extend[4][127:0]).
        4'd8:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[7][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ? {key_extend[5][63:0],key_extend[4][191:128]} :
                                                                            key_extend[3][255:128];

        4'd9:   round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[8][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[5][191:64] :
                                                                            key_extend[4][127:0];

        4'd10:  round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_128) ?  key_extend[9][127:0] :
                                   (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[6][127:0] :
                                                                            key_extend[4][255:128];

        4'd11:  round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_192) ? {key_extend[7][63:0],key_extend[6][191:128]} :
                                                                            key_extend[5][127:0];

        4'd12:  round_key[127:0] = (csr_aes_key_mode[1:0]==KEY_MODE_192) ?  key_extend[7][191:64] :
                                                                            key_extend[5][255:128];
        4'd13:  round_key[127:0] = key_extend[6][127:0];
        4'd14:  round_key[127:0] = key_extend[6][255:128];
        default:round_key[127:0] = csr_aes_key[127:0];
    endcase
end

// Capture the low 128 bits of the accepted request beat.
always @(posedge clk or negedge rst_n)begin
    if(!rst_n)
        aes_req_data_d[127:0] <= 128'b0;
    else if(aes_receive)
        aes_req_data_d[127:0] <= aes_req_data[127:0];
end

// Round input: captured plaintext in round 0, col-mixer feedback afterwards.
assign round_din[127:0] = round_cnt[3:0]==4'b0 ? aes_req_data_d[127:0] : data_col_mix_vector[127:0];


aes_key_add x_aes_key_add(

    .csr_aes_work_mode(csr_aes_work_mode),

    .key(round_key[127:0]),
    .din(round_din[127:0]),
    .vector(last_data[127:0]),

    .dout(data_key_add[127:0])
);


// Vector -> byte matrix: byte i occupies bits [8*i+7 : 8*i].
generate
    genvar i;
    for(i=0;i<16;i++)begin:loop_key_add
        assign data_key_add_matrix[i] = data_key_add[8*i+7-:8];
    end
endgenerate

// NOTE(review): only the forward sbox is wired here; de_sbox is never used,
// so csr_aes_mode=decode is only seen by aes_col_mixer. Confirm decrypt support.
aes_line_shifter_sbox x_aes_line_shifter_sbox(
    .sbox(sbox),
    .matrix(data_key_add_matrix),

    .matrix_s(data_line_shifter_matrix)
);


// Pipeline register on the SubBytes/ShiftRows output.
always @(posedge clk or negedge rst_n)begin
    if(!rst_n)
        for(integer b=0;b<16;b++)
            data_line_shifter_matrix_d[b] <= 8'h0;
    else
        for(integer b=0;b<16;b++)
            data_line_shifter_matrix_d[b] <= data_line_shifter_matrix[b];
end

// Byte matrix -> vector: byte j occupies bits [8*j+7 : 8*j].
generate
genvar j;
    for(j=0;j<16;j++)begin:loop_shifter_verctor
        assign data_line_shifter_vector[8*j+7-:8] = data_line_shifter_matrix_d[j];
    end
endgenerate


aes_col_mixer x_aes_col_mixer(

    .clk(clk),
    .rst_n(rst_n),

    .mode(csr_aes_mode),   //0:encode  1:decode
    .matrix(data_line_shifter_matrix_d),

    .matrix_m(data_col_mix_matrix)
);


// Byte matrix -> vector: byte k occupies bits [8*k+7 : 8*k].
generate
genvar k;
    for(k=0;k<16;k++)begin:loop_col_mix_vector
        assign data_col_mix_vector[k*8+7-:8] = data_col_mix_matrix[k];
    end
endgenerate

// Response valid: raised on the last round for the selected key length
// (10 / 12 / 14 rounds for 128 / 192 / 256-bit keys) and cleared on handshake.
always @(posedge clk or negedge rst_n)begin
    if(!rst_n)
        aes_rsp_vld <= 1'b0;
    else if(aes_rsp_vld & aes_rsp_rdy)
        aes_rsp_vld <= 1'b0;
    else if((csr_aes_key_mode[1:0]==KEY_MODE_128 && round_cnt[3:0]==4'd10) ||
            (csr_aes_key_mode[1:0]==KEY_MODE_192 && round_cnt[3:0]==4'd12) ||
            (csr_aes_key_mode[1:0]==KEY_MODE_256 && round_cnt[3:0]==4'd14))
        aes_rsp_vld <= 1'b1;
end


// NOTE(review): the response is taken straight from the ShiftRows register.
// No AddRoundKey is applied after it on this path; confirm where the
// final-round key addition is meant to happen.
assign aes_rsp_data[255:0] = {128'b0, data_line_shifter_vector[127:0]};

// Chaining register: IV on SOF, otherwise the last accepted response.
always @(posedge clk or negedge rst_n)begin
    if(!rst_n)
        last_data[127:0] <= 128'b0;
    else if(aes_req_sof & aes_receive)
        last_data[127:0] <= csr_aes_init_vector[127:0];
    else if(aes_rsp_vld & aes_rsp_rdy)
        last_data[127:0] <= aes_rsp_data[127:0];
end



endmodule
